import cv2
from matplotlib.widgets import Slider
import matplotlib.pyplot as plt
import cupy as cp
from tqdm.rich import tqdm

def gpu_fourier_denoise(image, cutoff_frequency):
    # 将图像转换为频率域
    image_gpu = cp.array(image)
    f = cp.fft.fft2(image_gpu)
    fshift = cp.fft.fftshift(f)

    # 创建一个与图像大小相同的掩码，中心为低通滤波区域
    rows, cols = image.shape
    crow, ccol = rows // 2, cols // 2
    mask = cp.zeros((rows, cols), cp.uint8)
    mask[
        crow - cutoff_frequency : crow + cutoff_frequency,
        ccol - cutoff_frequency : ccol + cutoff_frequency,
    ] = 1

    # 应用低通滤波器
    fshift_filtered = fshift * mask

    # 将过滤后的频率域图像转换回空间域
    f_ishift = cp.fft.ifftshift(fshift_filtered)
    img_back = cp.fft.ifft2(f_ishift)

    return img_back.real


def gpu_denoise_image(image, cutoff_frequency):
    # 分别对每个通道进行去噪
    denoised_channels = []
    for i in range(3):
        channel = image[:, :, i]
        denoised_channel = gpu_fourier_denoise(channel, cutoff_frequency)
        denoised_channels.append(denoised_channel)

    # 合并去噪后的通道
    denoised_img = cp.stack(denoised_channels, axis=2)

    # 将结果转换为合适的数据类型和范围
    denoised_img = cp.clip(denoised_img, 0, 255).astype(cp.uint8)

    return denoised_img


# 读取图像
image_path = "image.png"  # 替换为您的图片路径
original_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
# 将原始图像数据转换为cupy.ndarray
original_img_gpu = cp.array(original_img)
# 创建一个窗口来显示图像和滑块
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.25)
denoised_img = gpu_denoise_image(original_img_gpu, 30)  # 初始去噪
ax.imshow(cp.asnumpy(denoised_img))  # 显示图像
# 创建一个滑块来调整截止频率
axcolor = "lightgoldenrodyellow"
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
slider = Slider(axfreq, "Cutoff Frequency", 1, 100, valinit=30, valstep=1)


# 注册滑块事件的回调函数
def update(val):
    cutoff_frequency = slider.val
    denoised_img = gpu_denoise_image(original_img_gpu, cutoff_frequency)
    ax.imshow(cp.asnumpy(denoised_img))  # 更新图像显示
    fig.canvas.draw_idle()


slider.on_changed(update)
# 自动调整截止频率并显示图像
for cutoff_frequency in tqdm(range(1, 1001)):
    denoised_img = gpu_denoise_image(original_img_gpu, cutoff_frequency)
    ax.imshow(cp.asnumpy(denoised_img))  # 更新图像显示
    fig.canvas.draw_idle()
    # plt.pause(0.01)  # 暂停一下，以便观察每个频率下的变化
